"""
负责保存模型
"""
from Config.Config import *
import os
import torch

def CreatRun(idx, dis, root=None):
    """Create and return the first unused run directory ``<root>/train<dis>/epx<N>``.

    Starting at ``N = idx``, the index is incremented until a directory that
    does not exist yet can be created. Any missing parent directories are
    created as well.

    Args:
        idx: First run index to try.
        dis: Value appended to the ``train`` sub-directory name.
        root: Base directory for all runs. Defaults to ``SaveWeights``
            (provided by ``Config.Config``).

    Returns:
        The path (str) of the newly created run directory.
    """
    if root is None:
        root = SaveWeights
    # os.path.join instead of hard-coded "\\" so the layout is also correct
    # outside Windows.
    train_dir = os.path.join(root, "train" + str(dis))
    # Iterate instead of recursing so a large number of existing runs cannot
    # exhaust the recursion limit.
    while True:
        candidate = os.path.join(train_dir, "epx" + str(idx))
        try:
            # With exist_ok=False, the creation attempt is also the existence
            # test, so there is no gap between "check" and "create" that a
            # concurrent run could slip into.
            os.makedirs(candidate)
        except FileExistsError:
            idx += 1
            continue
        return candidate


def Save_Model(EXP_Path, weight_best, weight_last):
    """Save the best and last weights to ``<EXP_Path>/weights/{best,last}.pth``.

    The ``weights`` sub-directory is created if needed. Existing ``.pth``
    files in it are overwritten. The written paths are printed.

    Args:
        EXP_Path: An existing experiment directory.
        weight_best: Object serialized with ``torch.save`` as ``best.pth``.
        weight_last: Object serialized with ``torch.save`` as ``last.pth``.

    Raises:
        FileNotFoundError: If ``EXP_Path`` is not an existing directory. This is
            a subclass of ``Exception``, so existing ``except Exception``
            handlers keep working.
    """
    # Guard clause: a plain file (or missing path) is not a usable experiment
    # directory. The message means "invalid save path".
    if not os.path.isdir(EXP_Path):
        raise FileNotFoundError("保存地址异常")

    save_path_root = os.path.join(EXP_Path, "weights")
    os.makedirs(save_path_root, exist_ok=True)

    save_path_best = os.path.join(save_path_root, "best.pth")
    save_path_last = os.path.join(save_path_root, "last.pth")
    torch.save(weight_best, save_path_best)
    torch.save(weight_last, save_path_last)
    print()
    print("best:", save_path_best)
    print("last:", save_path_last)


if __name__ == '__main__':
    # Smoke test: write two small placeholder objects into a throwaway
    # experiment directory, so running this module directly leaves no files
    # behind. (The previous call passed the integer 0 as the path and could
    # never succeed.)
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_exp_dir:
        Save_Model(tmp_exp_dir, {"demo": "best"}, {"demo": "last"})